import numpy as np
from sklearn.model_selection import train_test_split

#collist=['holidayvibes','gamenight','happyholidays','carsoftiktok','fallguysmoments','homecooked','veteransday','youwantmore','interiordesign','coldweather','wildanimals','mycostume','meleaving','mypfp','catchphrases','watchmegrow','holidaycrafts','growupwithme','clingypet','happyhanukkah','lunarnewyear',
#         'tabletop','comfortfood','selfimprovement','2021affirmations','perfectmatch','givingszn','holidaycountdown','bakingszn','holidaymusic','familyimpression','inkdrawing','WeekendVibes','productivity','recordsday']#'holidaycountdown',,,,,'onlinedating','festivefashion','raisedby',
#collist=['holidayvibes']#,'gamenight','happyholidays','carsoftiktok','fallguysmoments']#,'interiordesign','homecooked','veteransday','youwantmore','interiordesign','coldweather','wildanimals','mycostume','meleaving','mypfp','catchphrases','watchmegrow','holidaycrafts','growupwithme','clingypet','happyhanukkah','lunarnewyear',
        #'tabletop','comfortfood','selfimprovement','2021affirmations','perfectmatch']#,'givingszn','holidaycountdown','bakingszn','holidaymusic','familyimpression','inkdrawing','WeekendVibes','recordsday','productivity','smallbusiness']
#collist=['holidayvibes']#,'gamenight','happyholidays','carsoftiktok','fallguysmoments','interiordesign','homecooked','veteransday','youwantmore','interiordesign','coldweather','wildanimals','mycostume','meleaving','mypfp','catchphrases','watchmegrow','holidaycrafts','growupwithme','clingypet','happyhanukkah','lunarnewyear',
      # 'tabletop','comfortfood','selfimprovement','2021affirmations','perfectmatch','givingszn','holidaycountdown','bakingszn','holidaymusic','familyimpression','inkdrawing','WeekendVibes','recordsday','productivity','smallbusiness'
      # ,'falldiy','whenwewereyounger','yellow','ComingOfAge','artmas','gaminglife','gamingsetup','hellowinter','planttiktok','housetour','neonshadow','homeoffice','raisedby','makeitvogue','foodtiktok','valentinesday','yougotthis','stemlife']
# Hashtags whose per-hashtag .npy files are concatenated into the dataset.
# The position of a hashtag in this list is also the index of its "hot"
# entry in the one-hot vectors stored in `hts`, so order matters.
#
# NOTE: every name must appear exactly once. A repeated hashtag has its
# files loaded twice, so identical samples can land in both the train and
# the test split, and the hashtag is assigned two different one-hot slots.
# ('interiordesign' used to be listed twice; the second occurrence was
# removed, which shrinks the one-hot width from 55 to 54.)
collist = [
    'holidayvibes', 'gamenight', 'happyholidays', 'carsoftiktok', 'fallguysmoments',
    'interiordesign', 'homecooked', 'veteransday', 'youwantmore', 'coldweather',
    'wildanimals', 'mycostume', 'meleaving', 'mypfp', 'catchphrases',
    'watchmegrow', 'holidaycrafts', 'growupwithme', 'clingypet', 'happyhanukkah',
    'lunarnewyear', 'tabletop', 'comfortfood', 'selfimprovement', '2021affirmations',
    'perfectmatch', 'givingszn', 'holidaycountdown', 'bakingszn', 'holidaymusic',
    'familyimpression', 'inkdrawing', 'WeekendVibes', 'recordsday', 'productivity',
    'smallbusiness', 'falldiy', 'whenwewereyounger', 'yellow', 'ComingOfAge',
    'artmas', 'gaminglife', 'gamingsetup', 'hellowinter', 'planttiktok',
    'housetour', 'neonshadow', 'homeoffice', 'raisedby', 'makeitvogue',
    'foodtiktok', 'valentinesday', 'yougotthis', 'stemlife',
]


# Per-sample accumulators. Each one starts as its own empty list and is
# extended once per hashtag by the loading loop further down:
#   t / t_sticker           <- 'text_' / 'text_sticker_' files
#   text / text_sticker     <- 'text_embed_' / 'text_sticker_embed_' files
#   img_emb / img_emb_new   <- 'image_embed_' / 'image_embed_new_' files
#   yamnet, edits, labels   <- 'yamnet_embed_', 'edit_embed_', 'label_' files
#   ids                     <- 'ids_' files
#   hts                     <- one one-hot hashtag vector per sample
(
    t, t_sticker, img_emb, yamnet, text, text_sticker,
    labels, edits, img_emb_new, hts, ids,
) = ([] for _ in range(11))

# Suffix appended to every output file name written at the end of the script.
version_num = 'd35'
# for col in collist:
#     text.extend( np.load('E:\\data_pi\\text_embed_'+col+'.npy'))
#     img_emb .extend( np.load('E:\\data_pi\\image_embed_'+col+'.npy'))
#     yamnet .extend( np.load('E:\\data_pi\\yamnet_embed_'+col+'.npy'))
#     edits .extend( np.load('E:\\data_pi\\edit_embed_'+col+'.npy'))
#     labels.extend( np.load('E:\\data_pi\\label_'+col+'.npy'))

def _load_part(kind, col):
    """Load ``E:\\data_pi\\<kind>_4h-15h_new_<col>.npy`` and return the array."""
    return np.load('E:\\data_pi\\' + kind + '_4h-15h_new_' + col + '.npy')


# Concatenate the per-hashtag files into the module-level accumulators.
# Every list is extended in the same hashtag order, which is what keeps
# row i of one list aligned with row i of the others.
for htc, col in enumerate(collist):
    # One-hot vector identifying this hashtag by its position in `collist`.
    htn = np.zeros(len(collist))
    htn[htc] = 1

    # Load the raw text once; its length is the sample count for this
    # hashtag (previously the file was read from disk a second time just
    # to obtain that length).
    raw_text = _load_part('text', col)
    t.extend(raw_text)
    # One reference to the same one-hot vector per sample of this hashtag.
    hts.extend([htn] * len(raw_text))

    ids.extend(_load_part('ids', col))
    text.extend(_load_part('text_embed', col))
    t_sticker.extend(_load_part('text_sticker', col))
    text_sticker.extend(_load_part('text_sticker_embed', col))

    img_emb.extend(_load_part('image_embed', col))
    yamnet.extend(_load_part('yamnet_embed', col))
    edits.extend(_load_part('edit_embed', col))
    labels.extend(_load_part('label', col))
    img_emb_new.extend(_load_part('image_embed_new', col))

# Split every modality in ONE call so that all of them share the same
# train/test row assignment. With equal-length inputs and a fixed
# random_state this gives the same partition as splitting each list
# separately, but a single call also checks that all inputs have the same
# number of samples and raises ValueError otherwise, instead of silently
# producing misaligned rows.
# Outputs come back as (train, test) pairs in the order the inputs are given.
(
    train_t, test_t,
    train_ids, test_ids,
    train_hts, test_hts,
    train_t_sticker, test_t_sticker,
    train_img_emb, test_img_emb,
    train_yamnet, test_yamnet,
    train_text, test_text,
    train_text_sticker, test_text_sticker,
    train_labels, test_labels,
    train_edits, test_edits,
    train_img_emb_new, test_img_emb_new,
) = train_test_split(
    t, ids, hts,
    t_sticker, img_emb, yamnet, text, text_sticker, labels, edits,
    img_emb_new,
    test_size=0.2,
    random_state=42,
)
# (output name, train part, test part) for every modality, in write order.
# Note the naming quirk: `img_emb` is written as 'image_embed_prob' while
# `img_emb_new` is written as plain 'image_embed'.
_split_outputs = (
    ('image_embed_prob', train_img_emb, test_img_emb),
    ('yamnet_embed', train_yamnet, test_yamnet),
    ('label', train_labels, test_labels),
    ('edit_embed', train_edits, test_edits),
    ('text_embed', train_text, test_text),
    ('text_sticker_embed', train_text_sticker, test_text_sticker),
    ('image_embed', train_img_emb_new, test_img_emb_new),
    ('text', train_t, test_t),
    ('text_sticker', train_t_sticker, test_t_sticker),
    ('ids', train_ids, test_ids),
    ('hts', train_hts, test_hts),
)

# Write E:\data_pi\{train,test}_<name>_<version_num> for each entry
# (np.save appends the .npy extension itself and accepts array-like input,
# so the lists are converted to arrays on the way out).
for _name, _train_part, _test_part in _split_outputs:
    for _split, _part in (('train', _train_part), ('test', _test_part)):
        np.save('E:\\data_pi\\' + _split + '_' + _name + '_' + version_num, np.array(_part))